package com.example.dependencyscanner.service;

import com.example.dependencyscanner.dao.VulnerabilityDao;
import com.example.dependencyscanner.model.DependencyInfo;
import com.example.dependencyscanner.model.DependencyRisk;
import com.example.dependencyscanner.model.Vulnerability;
import com.example.dependencyscanner.util.VersionRangeChecker;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;

/**
 * Vulnerability matching service.
 * <p>
 * Matches scanned dependencies against vulnerability data. {@link #matchVulnerabilities(List)}
 * currently always delegates to {@link OnlineVulnerabilityService}. The local-database and
 * hybrid matching paths in this class are implemented but are not reachable from the
 * scan-mode dispatch.
 */
@Service
public class VulnerabilityMatcher {

    private static final Logger logger = LoggerFactory.getLogger(VulnerabilityMatcher.class);

    @Autowired
    private VulnerabilityDao vulnerabilityDao;

    @Autowired
    private OnlineVulnerabilityService onlineVulnerabilityService;

    /** Configured scan mode from {@code vulnerability.scan.mode}; defaults to {@code online}. */
    @Value("${vulnerability.scan.mode:online}")
    private String scanMode;

    /**
     * Matches the given dependencies against known vulnerabilities.
     * <p>
     * The scan is always performed by the online service. The modes {@code local} and
     * {@code hybrid} are recognised but not enabled. Any other value (including {@code null})
     * is reported as unknown. Every case falls back to the online scan.
     *
     * @param dependencies dependencies to check; {@code null} is treated as "nothing to scan"
     * @return the risks found; never {@code null}
     */
    public List<DependencyRisk> matchVulnerabilities(List<DependencyInfo> dependencies) {
        if (dependencies == null) {
            logger.warn("依赖列表为 null，跳过漏洞匹配");
            return new ArrayList<>();
        }
        logger.info("开始匹配漏洞，共 {} 个依赖需要检查，扫描模式: {}", dependencies.size(), scanMode);

        if ("online".equalsIgnoreCase(scanMode)) {
            logger.info("使用在线漏洞扫描模式");
        } else if ("local".equalsIgnoreCase(scanMode) || "hybrid".equalsIgnoreCase(scanMode)) {
            // matchWithLocalDatabase / matchWithHybridMode exist but are not dispatched to
            // (their branches were previously commented out). Report the mode as disabled
            // instead of "unknown", then fall back to the online scan as before.
            logger.warn("扫描模式 {} 暂未启用，使用默认在线模式", scanMode);
        } else {
            logger.warn("未知的扫描模式: {}，使用默认在线模式", scanMode);
        }

        List<DependencyRisk> risks = onlineVulnerabilityService.scanDependencies(dependencies);
        if (risks == null) {
            // Defensive: a null result would NPE on the log line below and in callers.
            risks = new ArrayList<>();
        }

        logger.info("漏洞匹配完成，发现 {} 个风险依赖", risks.size());
        return risks;
    }

    /**
     * Matches every dependency against all vulnerability records returned by
     * {@link VulnerabilityDao#findAll()}.
     *
     * @param dependencies dependencies to check
     * @return all (dependency, vulnerability) risks found
     */
    private List<DependencyRisk> matchWithLocalDatabase(List<DependencyInfo> dependencies) {
        List<DependencyRisk> risks = new ArrayList<>();
        List<Vulnerability> allVulnerabilities = vulnerabilityDao.findAll();

        logger.info("本地漏洞数据库中共有 {} 条漏洞记录", allVulnerabilities.size());

        for (DependencyInfo dependency : dependencies) {
            risks.addAll(matchSingleDependency(dependency, allVulnerabilities));
        }

        return risks;
    }

    /**
     * Hybrid matching: runs the local database match first, then sends only the dependencies
     * that produced no local hit (by groupId:artifactId) to the online scanner.
     *
     * @param dependencies dependencies to check
     * @return local risks followed by online risks for the remaining dependencies
     */
    private List<DependencyRisk> matchWithHybridMode(List<DependencyInfo> dependencies) {
        List<DependencyRisk> localRisks = matchWithLocalDatabase(dependencies);
        List<DependencyRisk> risks = new ArrayList<>(localRisks);

        // Index the coordinates already flagged locally so each dependency is checked in O(1)
        // rather than by rescanning localRisks.
        Set<String> flaggedLocally = new HashSet<>();
        for (DependencyRisk risk : localRisks) {
            flaggedLocally.add(coordinateKey(risk.getGroupId(), risk.getArtifactId()));
        }

        List<DependencyInfo> unscannedDependencies = new ArrayList<>();
        for (DependencyInfo dependency : dependencies) {
            String key = coordinateKey(dependency.getGroupId(), dependency.getArtifactId());
            if (!flaggedLocally.contains(key)) {
                unscannedDependencies.add(dependency);
            }
        }

        if (!unscannedDependencies.isEmpty()) {
            logger.info("对 {} 个依赖进行在线补充扫描", unscannedDependencies.size());
            risks.addAll(onlineVulnerabilityService.scanDependencies(unscannedDependencies));
        }

        return risks;
    }

    /**
     * Collects the risks for a single dependency.
     *
     * @param dependency      the dependency to check
     * @param vulnerabilities candidate vulnerability records
     * @return one risk per vulnerability whose coordinates match and whose affected-version
     *         range contains the dependency's version
     */
    private List<DependencyRisk> matchSingleDependency(DependencyInfo dependency, List<Vulnerability> vulnerabilities) {
        List<DependencyRisk> risks = new ArrayList<>();

        for (Vulnerability vulnerability : vulnerabilities) {
            if (isMatch(dependency, vulnerability)
                    && isVersionVulnerable(dependency.getVersion(), vulnerability.getVulnerableVersions())) {
                risks.add(new DependencyRisk(dependency, vulnerability));

                logger.debug("发现漏洞: {}:{}:{} -> {}",
                           dependency.getGroupId(),
                           dependency.getArtifactId(),
                           dependency.getVersion(),
                           vulnerability.getCve());
            }
        }

        return risks;
    }

    /**
     * Decides whether a vulnerability record applies to a dependency's coordinates.
     * An exact groupId + artifactId match is tried first, then the fuzzy heuristics.
     *
     * @param dependency    dependency information
     * @param vulnerability vulnerability record
     * @return {@code true} if the record is considered to apply; {@code false} if any
     *         coordinate is {@code null}
     */
    private boolean isMatch(DependencyInfo dependency, Vulnerability vulnerability) {
        String depGroupId = dependency.getGroupId();
        String depArtifactId = dependency.getArtifactId();
        String vulnGroupId = vulnerability.getGroupId();
        String vulnArtifactId = vulnerability.getArtifactId();

        // Missing coordinates cannot be matched reliably (and previously caused an NPE).
        if (depGroupId == null || depArtifactId == null || vulnGroupId == null || vulnArtifactId == null) {
            return false;
        }

        if (depGroupId.equals(vulnGroupId) && depArtifactId.equals(vulnArtifactId)) {
            return true;
        }

        return isFuzzyMatch(dependency, vulnerability);
    }

    /**
     * Case-insensitive heuristics for coordinates that do not match exactly.
     * Callers must ensure all four coordinates are non-null.
     */
    private boolean isFuzzyMatch(DependencyInfo dependency, Vulnerability vulnerability) {
        // Locale.ROOT keeps lowercasing independent of the JVM default locale (e.g. Turkish 'I').
        String depGroupId = dependency.getGroupId().toLowerCase(Locale.ROOT);
        String depArtifactId = dependency.getArtifactId().toLowerCase(Locale.ROOT);
        String vulnGroupId = vulnerability.getGroupId().toLowerCase(Locale.ROOT);
        String vulnArtifactId = vulnerability.getArtifactId().toLowerCase(Locale.ROOT);

        // groupId could not be determined for the dependency: match on artifactId alone.
        if ("unknown".equals(depGroupId)) {
            return depArtifactId.equals(vulnArtifactId);
        }

        // Containment heuristic on both artifactId and groupId. An empty string is contained in
        // everything, so empty coordinates are excluded to avoid matching every dependency.
        if (mutuallyContains(depArtifactId, vulnArtifactId) && mutuallyContains(depGroupId, vulnGroupId)) {
            return true;
        }

        return isCommonVariation(depGroupId, depArtifactId, vulnGroupId, vulnArtifactId);
    }

    /** Returns true when both strings are non-empty and either one contains the other. */
    private static boolean mutuallyContains(String a, String b) {
        return !a.isEmpty() && !b.isEmpty() && (a.contains(b) || b.contains(a));
    }

    /**
     * Known naming variations between dependency and vulnerability coordinates.
     * All arguments are expected to be lowercased already.
     */
    private boolean isCommonVariation(String depGroupId, String depArtifactId,
                                     String vulnGroupId, String vulnArtifactId) {
        // Spring: any org.springframework* group, same artifactId.
        if (depGroupId.startsWith("org.springframework") && vulnGroupId.startsWith("org.springframework")) {
            return depArtifactId.equals(vulnArtifactId);
        }

        // Jackson: any group containing "jackson", same artifactId.
        if (depGroupId.contains("jackson") && vulnGroupId.contains("jackson")) {
            return depArtifactId.equals(vulnArtifactId);
        }

        // Log4j: any two artifactIds that both contain "log4j".
        // NOTE(review): very broad - e.g. "slf4j-log4j12" or "log4j-api" will match records for
        // "log4j-core". Confirm this over-reporting is intended.
        if (depArtifactId.contains("log4j") && vulnArtifactId.contains("log4j")) {
            return true;
        }

        return false;
    }

    /**
     * Checks whether a version falls inside a vulnerable version range.
     *
     * @param version            the dependency's version; {@code null} or {@code "unknown"}
     *                           is never considered vulnerable
     * @param vulnerableVersions the affected-version range; {@code null} means no match
     * @return the result of {@link VersionRangeChecker#isVulnerable(String, String)}, or
     *         {@code false} when either input is missing
     */
    private boolean isVersionVulnerable(String version, String vulnerableVersions) {
        if (version == null || "unknown".equals(version) || vulnerableVersions == null) {
            return false;
        }

        return VersionRangeChecker.isVulnerable(version, vulnerableVersions);
    }

    /**
     * Sorts risks in place: highest risk level first, then by {@code groupId:artifactId}.
     *
     * @param risks the list to sort (modified in place)
     * @return the same list instance, now sorted
     */
    public List<DependencyRisk> sortByRiskLevel(List<DependencyRisk> risks) {
        Comparator<DependencyRisk> byPriorityDesc =
                Comparator.comparingInt((DependencyRisk r) -> getRiskPriority(r.getRiskLevel())).reversed();
        risks.sort(byPriorityDesc.thenComparing(r -> coordinateKey(r.getGroupId(), r.getArtifactId())));
        return risks;
    }

    /**
     * Builds the {@code groupId:artifactId} key used for sorting and hybrid-mode de-duplication.
     * Assumes coordinates never contain ':' themselves (true for Maven coordinates).
     */
    private static String coordinateKey(String groupId, String artifactId) {
        return groupId + ":" + artifactId;
    }

    /**
     * Maps a risk level name to a numeric priority (critical=4, high=3, medium=2, low=1,
     * anything else or {@code null}=0). Matching is case-insensitive.
     */
    private int getRiskPriority(String riskLevel) {
        if (riskLevel == null) {
            return 0;
        }

        switch (riskLevel.toLowerCase(Locale.ROOT)) {
            case "critical":
                return 4;
            case "high":
                return 3;
            case "medium":
                return 2;
            case "low":
                return 1;
            default:
                return 0;
        }
    }

    /**
     * Counts risks per level.
     *
     * @param risks the risks to count
     * @return per-level counts; {@code totalCount} is the list size, so risks with a null or
     *         unrecognised level count toward the total only
     */
    public RiskStatistics getRiskStatistics(List<DependencyRisk> risks) {
        RiskStatistics stats = new RiskStatistics();

        for (DependencyRisk risk : risks) {
            String level = risk.getRiskLevel();
            if (level == null) {
                continue;
            }
            switch (level.toLowerCase(Locale.ROOT)) {
                case "critical":
                    stats.criticalCount++;
                    break;
                case "high":
                    stats.highCount++;
                    break;
                case "medium":
                    stats.mediumCount++;
                    break;
                case "low":
                    stats.lowCount++;
                    break;
                default:
                    // Unrecognised level: included in totalCount only.
                    break;
            }
        }

        stats.totalCount = risks.size();
        return stats;
    }

    /**
     * Mutable per-level risk counters produced by {@link #getRiskStatistics(List)}.
     */
    public static class RiskStatistics {
        public int totalCount = 0;
        public int criticalCount = 0;
        public int highCount = 0;
        public int mediumCount = 0;
        public int lowCount = 0;

        @Override
        public String toString() {
            return String.format("总计: %d, 严重: %d, 高危: %d, 中危: %d, 低危: %d",
                               totalCount, criticalCount, highCount, mediumCount, lowCount);
        }
    }
}